from otree.api import *
import random
import math

doc = """
PGG with judges + Norms & Empirical Expectations (Pre/Post).
"""

class C(BaseConstants):
    # --- App identity and structure ---
    NAME_IN_URL = 'base'
    # None: all players share a single oTree group; PGG groups are tracked
    # manually through Player.pgg_group.
    PLAYERS_PER_GROUP = None
    # 3 rounds while testing; raise to 15 for the real experiment.
    NUM_ROUNDS = 3

    # --- Public goods game parameters ---
    # Tokens each contributor receives per round.
    ENDOWMENT = 20
    # Return to every group member per token in the group pot
    # (payoff = ENDOWMENT - contribution + MPCR * group total).
    MPCR = 0.4

class Subsession(BaseSubsession):
    """One round of this app; no extra fields beyond oTree's defaults."""

class Group(BaseGroup):
    """The single oTree group (PLAYERS_PER_GROUP is None); PGG groups live on Player.pgg_group."""

class Player(BasePlayer):
    # One row per participant per round. Pre-play norm answers are used from
    # the round-1 row; post-play norm answers and all payment/bonus fields are
    # used from the round-C.NUM_ROUNDS row.

    # --- Role / grouping (copied from participant.vars by SetupWaitPage each round) ---
    is_judge = models.BooleanField(initial=False)
    # Judge: the subgroup they judge. Contributor: the subgroup they belong to.
    subgroup = models.IntegerField(initial=0)
    # PGG group id, reshuffled every round; stays 0 for judges.
    pgg_group = models.IntegerField(initial=0)
    # Participant ID typed on the IDInput page (round 1).
    pid_input = models.StringField(initial='')

    # --- PGG fields ---
    # Upper bound tied to the endowment instead of a duplicated literal 20.
    contribution = models.IntegerField(min=0, max=C.ENDOWMENT, initial=0)
    # Unrounded payoff before punishment.
    payoff_pre = models.FloatField(initial=0.0)
    # Punishment points assigned to this player this round.
    punishment_received = models.IntegerField(initial=0)
    # This player's share of the PGG group's total punishment.
    group_punishment_cost = models.FloatField(initial=0.0)

    # --- Judge fields ---
    # pN_punish: points (0-5) for the N-th contributor of the judge's subgroup,
    # in subsession order. There are 12 fields, so a subgroup can hold at most
    # 12 contributors.
    p1_punish = models.IntegerField(min=0, max=5, initial=0)
    p2_punish = models.IntegerField(min=0, max=5, initial=0)
    p3_punish = models.IntegerField(min=0, max=5, initial=0)
    p4_punish = models.IntegerField(min=0, max=5, initial=0)
    p5_punish = models.IntegerField(min=0, max=5, initial=0)
    p6_punish = models.IntegerField(min=0, max=5, initial=0)
    p7_punish = models.IntegerField(min=0, max=5, initial=0)
    p8_punish = models.IntegerField(min=0, max=5, initial=0)
    p9_punish = models.IntegerField(min=0, max=5, initial=0)
    p10_punish = models.IntegerField(min=0, max=5, initial=0)
    p11_punish = models.IntegerField(min=0, max=5, initial=0)
    p12_punish = models.IntegerField(min=0, max=5, initial=0)

    # --- Norms pre fields (answered in round 1) ---
    # np_X: personal-norm answer for contribution level X.
    # NOTE(review): np_*/nn_* share identical labels; presumably the templates
    # supply the real wording — confirm.
    np_0 = models.IntegerField(label="Avg contrib 0:", min=0, max=20)
    np_5 = models.IntegerField(label="Avg contrib 5:", min=0, max=20)
    np_10 = models.IntegerField(label="Avg contrib 10:", min=0, max=20)
    np_15 = models.IntegerField(label="Avg contrib 15:", min=0, max=20)
    np_20 = models.IntegerField(label="Avg contrib 20:", min=0, max=20)

    # nn_X: guess of the session-average personal-norm answer at level X
    # (scored against the np_X average).
    nn_0 = models.IntegerField(label="Avg contrib 0:", min=0, max=20)
    nn_5 = models.IntegerField(label="Avg contrib 5:", min=0, max=20)
    nn_10 = models.IntegerField(label="Avg contrib 10:", min=0, max=20)
    nn_15 = models.IntegerField(label="Avg contrib 15:", min=0, max=20)
    nn_20 = models.IntegerField(label="Avg contrib 20:", min=0, max=20)

    # Guess of the session-average contribution (a contribution, so bounded by the endowment).
    expected_contribution = models.IntegerField(
        label="Avg contribution session:", min=0, max=C.ENDOWMENT
    )

    # --- Norms post fields (answered in the last round) ---
    np_0_post = models.IntegerField(label="Avg contrib 0:", min=0, max=20)
    np_5_post = models.IntegerField(label="Avg contrib 5:", min=0, max=20)
    np_10_post = models.IntegerField(label="Avg contrib 10:", min=0, max=20)
    np_15_post = models.IntegerField(label="Avg contrib 15:", min=0, max=20)
    np_20_post = models.IntegerField(label="Avg contrib 20:", min=0, max=20)

    nn_0_post = models.IntegerField(label="Avg contrib 0:", min=0, max=20)
    nn_5_post = models.IntegerField(label="Avg contrib 5:", min=0, max=20)
    nn_10_post = models.IntegerField(label="Avg contrib 10:", min=0, max=20)
    nn_15_post = models.IntegerField(label="Avg contrib 15:", min=0, max=20)
    nn_20_post = models.IntegerField(label="Avg contrib 20:", min=0, max=20)

    expected_contribution_post = models.IntegerField(
        label="Avg contribution session:", min=0, max=C.ENDOWMENT
    )

    # --- Norm scoring results (stored for data export and payment) ---
    # Contribution levels drawn for scoring the pre and post questions.
    chosen_level_pre = models.IntegerField(initial=0)
    chosen_level_post = models.IntegerField(initial=0)

    # Session averages that the guesses are scored against.
    avg_personal_norm_pre = models.FloatField(initial=0.0)
    avg_personal_norm_post = models.FloatField(initial=0.0)
    avg_contribution_session = models.FloatField(initial=0.0)

    # Per-guess bonuses (0 unless the guess was close enough).
    bonus_norm_pre = models.CurrencyField(initial=0)
    bonus_norm_post = models.CurrencyField(initial=0)
    bonus_emp_pre = models.CurrencyField(initial=0)
    bonus_emp_post = models.CurrencyField(initial=0)

    # --- Payment fields ---
    # Which rounds were paid, as a string such as "1, 5, 8".
    paid_rounds_str = models.StringField(initial="")

    # Sum of the payoffs in the paid rounds.
    pgg_payoff_sum = models.CurrencyField(initial=0)

    # PGG payment in CZK (pgg_payoff_sum * 3).
    pgg_czk_total = models.CurrencyField(initial=0)

    # Total norms bonus (pre + post, normative + empirical).
    norms_total_bonus = models.CurrencyField(initial=0)

    # Final payment in CZK (PGG CZK + norms bonus).
    final_payment_czk = models.CurrencyField(initial=0)


### TESTING ONLY: creating_session below prefills the norm questions with random numbers. Comment it out for the real experiment.

def creating_session(subsession: Subsession):
    """Prefill the norm questions with random answers (testing only).

    oTree calls this once per round (once per Subsession) when the session
    is created. The pre-play answers are read from the round-1 Player and
    the post-play answers from the round-C.NUM_ROUNDS Player, so each set is
    written in its own round.

    Remove or comment out for the real experiment: prefilled values show up
    as form defaults.
    """
    levels = [0, 5, 10, 15, 20]
    round_number = subsession.round_number

    if round_number == 1:
        for p in subsession.get_players():
            for level in levels:
                setattr(p, f'np_{level}', random.randint(0, 20))
                setattr(p, f'nn_{level}', random.randint(0, 20))
            p.expected_contribution = random.randint(0, 20)

    # Not an elif: with NUM_ROUNDS == 1 both sets belong to round 1.
    # Previously the post answers were written onto round-1 players, where
    # neither the post-norm pages nor the payment page ever read them.
    if round_number == C.NUM_ROUNDS:
        for p in subsession.get_players():
            for level in levels:
                setattr(p, f'np_{level}_post', random.randint(0, 20))
                setattr(p, f'nn_{level}_post', random.randint(0, 20))
            p.expected_contribution_post = random.randint(0, 20)

# =====================================================
# PAGES
# =====================================================

class IDInput(Page):
    """Asks for the participant's ID (stored in pid_input); round 1 only."""

    form_fields = ['pid_input']
    form_model = 'player'

    @staticmethod
    def is_displayed(player: Player):
        first_round = player.round_number == 1
        return first_round

class Intro(Page):
    """Introduction page, round 1 only; passes the treatment name to the template."""

    @staticmethod
    def is_displayed(player: Player):
        return 1 == player.round_number

    @staticmethod
    def vars_for_template(player: Player):
        config = player.session.config
        return {'treatment': config.get('treatment_type')}


class SetupWaitPage(WaitPage):
    @staticmethod
    def after_all_players_arrive(group: Group):
        """Assign roles and subgroups (round 1), then regroup players (every round).

        Round 1 stores ``is_judge`` and ``subgroup`` in participant.vars so the
        assignment persists for the whole session. In the judge treatments the
        players with id_in_subsession 1 and 2 judge subgroups 1 and 2, and the
        remaining players are randomly split between the two subgroups. In
        every other treatment all players are contributors in subgroup 1.

        Every round copies the assignment onto the Player, clears the
        punishment fields, and randomly regroups contributors into PGG groups
        of (up to) 4 within their subgroup.

        Raises:
            ValueError: in judge treatments, if the contributors cannot be
                split into two subgroups whose sizes are multiples of 4, or if
                a subgroup would exceed the number of ``pN_punish`` fields on
                Player (the judge could not rate everyone).
        """
        # Player defines p1_punish .. p12_punish, so a judge can rate at most 12 players.
        max_judged = 12

        subsession = group.subsession
        players = subsession.get_players()
        treatment = subsession.session.config.get('treatment_type')

        # 1. Round 1: assign fixed judges and subgroups.
        if subsession.round_number == 1:
            for p in players:
                p.participant.vars['is_judge'] = False
                p.participant.vars['subgroup'] = 0

            # Only the human-judge treatments have judge players.
            if treatment in ['human_judge', 'humanAI_judge']:
                # Judges are Player 1 (subgroup 1) and Player 2 (subgroup 2).
                for p in players:
                    if p.id_in_subsession == 1:
                        p.participant.vars['is_judge'] = True
                        p.participant.vars['subgroup'] = 1
                    elif p.id_in_subsession == 2:
                        p.participant.vars['is_judge'] = True
                        p.participant.vars['subgroup'] = 2

                normal_players = [p for p in players if not p.participant.vars['is_judge']]
                N = len(normal_players)

                # Candidate (a, b) subgroup sizes: both positive multiples of 4.
                valid_splits = []
                for a in range(4, N, 4):
                    b = N - a
                    if b >= 4 and b % 4 == 0:
                        valid_splits.append((a, b))

                if not valid_splits:
                    # Every positive multiple of 8 has a valid split, so this
                    # fallback is only reached for N == 0.
                    if N % 8 == 0:
                        a, b = N // 2, N // 2
                    else:
                        raise ValueError(f"Cannot split {N} players into subgroups divisible by 4.")
                else:
                    # Most balanced split; on ties the smaller first subgroup wins.
                    a, b = min(valid_splits, key=lambda x: abs(x[0] - x[1]))

                # Fail here rather than later on the Judge page / JudgeWaitPage,
                # which would reference a nonexistent p13_punish field.
                if max(a, b) > max_judged:
                    raise ValueError(
                        f"Subgroups of {a} and {b} players exceed the "
                        f"{max_judged} players one judge can rate."
                    )

                random.shuffle(normal_players)
                for p in normal_players[:a]:
                    p.participant.vars['subgroup'] = 1
                for p in normal_players[a:a + b]:
                    p.participant.vars['subgroup'] = 2

            else:
                # no_judge and AI_judge: everyone is a contributor in subgroup 1.
                for p in players:
                    p.participant.vars['subgroup'] = 1

        # 2. Every round: copy the assignment and reset per-round values.
        for p in players:
            p.is_judge = p.participant.vars.get('is_judge', False)
            p.subgroup = p.participant.vars.get('subgroup', 0)
            p.pgg_group = 0

            p.punishment_received = 0
            p.group_punishment_cost = 0.0
            if p.is_judge:
                for i in range(1, max_judged + 1):
                    setattr(p, f'p{i}_punish', 0)

        # 3. Every round: reshuffle PGG groups of 4 within each subgroup.
        #    Group ids are unique across subgroups.
        normal_players = [p for p in players if not p.is_judge]
        global_group_id = 1
        for sg in sorted(set(p.subgroup for p in normal_players)):
            sg_players = [p for p in normal_players if p.subgroup == sg]
            random.shuffle(sg_players)
            for i in range(0, len(sg_players), 4):
                for p in sg_players[i:i + 4]:
                    p.pgg_group = global_group_id
                global_group_id += 1

class Instructions(Page):
    """Instructions, shown once in round 1 (after roles are assigned)."""

    @staticmethod
    def is_displayed(player: Player):
        return player.round_number == 1

    @staticmethod
    def vars_for_template(player: Player):
        # is_judge is already set because SetupWaitPage precedes this page.
        return dict(
            treatment=player.session.config.get('treatment_type'),
            is_judge=player.is_judge,
            endowment=C.ENDOWMENT,
            mpcr=C.MPCR,
        )

class NormsPersonal(Page):
    """Pre-play personal-norm questions, one per contribution level; round 1 only."""

    form_model = 'player'
    form_fields = [f'np_{level}' for level in (0, 5, 10, 15, 20)]

    @staticmethod
    def is_displayed(player: Player):
        is_first_round = player.round_number == 1
        return is_first_round

class NormsNormative(Page):
    """Pre-play guesses of the average personal-norm answer per level; round 1 only."""

    form_model = 'player'
    form_fields = [f'nn_{level}' for level in (0, 5, 10, 15, 20)]

    @staticmethod
    def is_displayed(player: Player):
        is_first_round = player.round_number == 1
        return is_first_round

class NormsEmpirical(Page):
    """Pre-play guess of the session-average contribution; round 1 only."""

    form_model = 'player'
    form_fields = ['expected_contribution']

    @staticmethod
    def is_displayed(player: Player):
        is_first_round = player.round_number == 1
        return is_first_round


class Cooperation(Page):
    """Contribution decision in the PGG; judges skip this page."""

    form_model = 'player'
    form_fields = ['contribution']

    @staticmethod
    def is_displayed(player: Player):
        if player.is_judge:
            return False
        return True

    @staticmethod
    def vars_for_template(player: Player):
        return {'endowment': C.ENDOWMENT, 'mpcr': C.MPCR}

class ResultsWaitPage(WaitPage):
    @staticmethod
    def after_all_players_arrive(group: Group):
        """Compute pre-punishment PGG payoffs and the judges' payoffs.

        Each contributor gets ENDOWMENT - contribution + MPCR * (their PGG
        group's total contribution). The exact value goes to ``payoff_pre``;
        ``payoff`` holds it rounded up. Each judge is paid the rounded-up mean
        ``payoff`` of the contributors in their subgroup, or 0 if there are none.
        """
        everyone = group.subsession.get_players()
        contributors = [p for p in everyone if not p.is_judge]

        # Bucket contributors by PGG group, keeping subsession order inside each bucket.
        members_by_group = {}
        for p in contributors:
            members_by_group.setdefault(p.pgg_group, []).append(p)

        for members in members_by_group.values():
            pot = sum(m.contribution for m in members)
            for m in members:
                m.payoff_pre = float(C.ENDOWMENT - m.contribution + (C.MPCR * pot))
                m.payoff = math.ceil(m.payoff_pre)

        for judge in (p for p in everyone if p.is_judge):
            judged = [p for p in contributors if p.subgroup == judge.subgroup]
            if not judged:
                judge.payoff = 0
                continue
            mean_payoff = float(sum(p.payoff for p in judged) / len(judged))
            judge.payoff = math.ceil(mean_payoff)

class JudgeWaitPage(WaitPage):
    @staticmethod
    def after_all_players_arrive(group: Group):
        """Apply or suggest rule-based punishment for this round.

        Rule: min(5, max(0, round((subgroup mean contribution - own) / 2)))
        points; ``round`` is Python's round-half-to-even.

        - 'AI_judge': points go straight into ``punishment_received``
          (subgroup means over all players, since there are no judges).
        - 'humanAI_judge': points are written into each judge's
          ``p{i}_punish`` fields as defaults, in the same contributor order
          the Judge page uses.
        - Any other treatment ('no_judge', 'human_judge'): nothing happens here.
        """
        subsession = group.subsession
        treatment = subsession.session.config.get('treatment_type')

        def rule_points(mean_contribution, own_contribution):
            return min(5, max(0, round((mean_contribution - own_contribution) / 2)))

        if treatment == 'AI_judge':
            everyone = subsession.get_players()
            for sg in set(p.subgroup for p in everyone):
                members = [p for p in everyone if p.subgroup == sg]
                mean = sum(p.contribution for p in members) / len(members)
                for p in members:
                    p.punishment_received = rule_points(mean, p.contribution)

        elif treatment == 'humanAI_judge':
            everyone = subsession.get_players()
            for judge in [p for p in everyone if p.is_judge]:
                judged = [
                    p for p in everyone
                    if not p.is_judge and p.subgroup == judge.subgroup
                ]
                if not judged:
                    continue
                mean = sum(p.contribution for p in judged) / len(judged)
                for position, p in enumerate(judged, start=1):
                    setattr(judge, f'p{position}_punish', rule_points(mean, p.contribution))

def _players_judged_by(judge: Player):
    """Return the contributors in ``judge``'s subgroup, in subsession order.

    The 1-based position in this list selects the judge's ``p{i}_punish`` field.
    """
    return [
        p for p in judge.subsession.get_players()
        if not p.is_judge and p.subgroup == judge.subgroup
    ]


class Judge(Page):
    """Human judges assign punishment points to each contributor in their subgroup."""

    form_model = 'player'

    @staticmethod
    def is_displayed(player: Player):
        if not player.is_judge:
            return False
        treatment = player.session.config.get('treatment_type')
        return treatment in ['human_judge', 'humanAI_judge']

    @staticmethod
    def get_form_fields(player: Player):
        count = len(_players_judged_by(player))
        return [f'p{n}_punish' for n in range(1, count + 1)]

    @staticmethod
    def vars_for_template(player: Player):
        # Group the form items by PGG group, in first-seen order.
        pgg_groups = {}
        for n, p in enumerate(_players_judged_by(player), start=1):
            pgg_groups.setdefault(p.pgg_group, []).append(
                {'player': p, 'field': f'p{n}_punish'}
            )
        return dict(pgg_groups=pgg_groups)

    @staticmethod
    def before_next_page(player: Player, timeout_happened):
        # Copy the judge's ratings onto the judged players.
        for n, p in enumerate(_players_judged_by(player), start=1):
            p.punishment_received = getattr(player, f'p{n}_punish')

class FinalWaitPage(WaitPage):
    @staticmethod
    def after_all_players_arrive(group: Group):
        """Apply punishment and set each contributor's final payoff for the round.

        Within each PGG group, the group's total punishment points are shared
        equally by its members (``group_punishment_cost``). Each member's
        payoff becomes ceil(payoff_pre - own punishment - shared cost).
        Judges' payoffs are not touched here.
        """
        players = group.subsession.get_players()
        normal_players = [p for p in players if not p.is_judge]

        pgg_ids = set(p.pgg_group for p in normal_players)
        for g_id in pgg_ids:
            members = [p for p in normal_players if p.pgg_group == g_id]

            total_punish = sum(m.punishment_received for m in members)
            # Divide by the actual group size. A hard-coded 4 under-charged the
            # smaller remainder group that forms when the number of
            # contributors is not a multiple of 4 (possible in no_judge/AI_judge).
            # `members` is never empty because g_id comes from a member.
            shared_cost = total_punish / len(members)

            for m in members:
                m.group_punishment_cost = shared_cost
                final_val = float(m.payoff_pre - m.punishment_received - shared_cost)
                m.payoff = math.ceil(final_val)

class Results(Page):
    @staticmethod
    def vars_for_template(player: Player):
        """Anonymised outcome of this player's PGG group (empty for judges).

        Members are labelled A, B, C, ... in subsession order.
        """
        if player.is_judge:
            return dict()

        group_members = [
            p for p in player.subsession.get_players()
            if not p.is_judge and p.pgg_group == player.pgg_group
        ]
        anon_members = [
            {
                'label': chr(ord('A') + i),
                'contribution': m.contribution,
                'punishment': m.punishment_received,
            }
            for i, m in enumerate(group_members)
        ]
        return dict(
            anon_members=anon_members,
            total_group_punish=sum(m.punishment_received for m in group_members),
            total_contribution=sum(m.contribution for m in group_members),
        )

class NormsPersonalPost(Page):
    """Post-play personal-norm questions, one per contribution level; last round only."""

    form_model = 'player'
    form_fields = [f'np_{level}_post' for level in (0, 5, 10, 15, 20)]

    @staticmethod
    def is_displayed(player: Player):
        is_last_round = player.round_number == C.NUM_ROUNDS
        return is_last_round

class NormsNormativePost(Page):
    """Post-play guesses of the average personal-norm answer per level; last round only."""

    form_model = 'player'
    form_fields = [f'nn_{level}_post' for level in (0, 5, 10, 15, 20)]

    @staticmethod
    def is_displayed(player: Player):
        is_last_round = player.round_number == C.NUM_ROUNDS
        return is_last_round

class NormsEmpiricalPost(Page):
    """Post-play guess of the session-average contribution; last round only."""

    form_model = 'player'
    form_fields = ['expected_contribution_post']

    @staticmethod
    def is_displayed(player: Player):
        is_last_round = player.round_number == C.NUM_ROUNDS
        return is_last_round


class NormsAndPaymentWaitPage(WaitPage):
    @staticmethod
    def is_displayed(player: Player):
        # All work happens in the last round. Showing the page in earlier
        # rounds only added a redundant all-player barrier.
        return player.round_number == C.NUM_ROUNDS

    @staticmethod
    def after_all_players_arrive(group: Group):
        """Score the norm elicitations and compute every player's final payment.

        Runs once, in the last round:

        1. Draw one contribution level for scoring the pre questions and one
           for the post questions.
        2. Average the personal-norm answers (np_*) at those levels: pre
           answers from round 1, post answers from the last round.
        3. Average the contribution of all non-judges over every round of
           this app.
        4. Pay a bonus for each of the four guesses (normative nn_* vs the
           personal-norm average; empirical expected_contribution vs the
           average contribution; pre and post) that lands within the tolerance.
        5. Per player, draw up to 3 paid rounds (all rounds if fewer), sum the
           payoffs in them, convert to CZK and add the norms bonus.

        Results are stored on the last-round Player and in participant.payoff.
        """
        subsession = group.subsession
        if subsession.round_number != C.NUM_ROUNDS:
            return

        levels = [0, 5, 10, 15, 20]
        norm_bonus = 50       # paid per guess within tolerance
        tolerance = 3         # max distance for a guess to count
        czk_per_point = 3     # PGG payoff -> CZK multiplier
        max_paid_rounds = 3

        players_round1 = subsession.in_round(1).get_players()
        players_roundN = subsession.get_players()

        # 1. Random scoring levels.
        pre_level = random.choice(levels)
        post_level = random.choice(levels)

        # 2. Personal-norm averages (judges answered too, so they are included).
        pre_vals = [getattr(p, f'np_{pre_level}') for p in players_round1]
        avg_pre_personal = sum(pre_vals) / len(pre_vals) if pre_vals else 0

        post_vals = [getattr(p, f'np_{post_level}_post') for p in players_roundN]
        avg_post_personal = sum(post_vals) / len(post_vals) if post_vals else 0

        # 3. Average session contribution. Restricted to this app's rounds:
        #    session.get_subsessions() also returns other apps' subsessions,
        #    whose players have no is_judge/contribution fields.
        all_normal_players = [
            p
            for ss in subsession.in_all_rounds()
            for p in ss.get_players()
            if not p.is_judge
        ]
        if all_normal_players:
            avg_contrib = sum(p.contribution for p in all_normal_players) / len(all_normal_players)
        else:
            avg_contrib = 0

        all_round_nums = list(range(1, C.NUM_ROUNDS + 1))
        num_paid = min(max_paid_rounds, len(all_round_nums))

        for pN in players_roundN:
            p1 = pN.in_round(1)

            # Common session info, copied onto every player for easy export.
            pN.chosen_level_pre = pre_level
            pN.chosen_level_post = post_level
            pN.avg_personal_norm_pre = avg_pre_personal
            pN.avg_personal_norm_post = avg_post_personal
            pN.avg_contribution_session = avg_contrib

            # 4. Norm bonuses; the fields keep their initial 0 otherwise.
            if abs(getattr(p1, f'nn_{pre_level}') - avg_pre_personal) <= tolerance:
                pN.bonus_norm_pre = norm_bonus
            if abs(p1.expected_contribution - avg_contrib) <= tolerance:
                pN.bonus_emp_pre = norm_bonus
            if abs(getattr(pN, f'nn_{post_level}_post') - avg_post_personal) <= tolerance:
                pN.bonus_norm_post = norm_bonus
            if abs(pN.expected_contribution_post - avg_contrib) <= tolerance:
                pN.bonus_emp_post = norm_bonus

            # 5. PGG payment from randomly drawn rounds (drawn per player).
            selected_rounds = sorted(random.sample(all_round_nums, num_paid))
            pgg_sum = sum(pN.in_round(r_num).payoff for r_num in selected_rounds)
            pgg_czk = pgg_sum * czk_per_point

            norms_sum = (pN.bonus_norm_pre + pN.bonus_emp_pre +
                         pN.bonus_norm_post + pN.bonus_emp_post)
            final_total = pgg_czk + norms_sum

            # Stored as "1, 5, 8", the format documented on the field.
            pN.paid_rounds_str = ', '.join(str(r) for r in selected_rounds)
            pN.pgg_payoff_sum = pgg_sum
            pN.pgg_czk_total = pgg_czk
            pN.norms_total_bonus = norms_sum
            pN.final_payment_czk = final_total

            # Overwrites the per-round payoff total oTree accumulates, so the
            # admin shows the CZK amount.
            # NOTE(review): oTree converts participant.payoff using
            # real_world_currency_per_point — confirm that is 1 for this session.
            pN.participant.payoff = final_total


class ResultsSummary(Page):
    """Last-round summary of the norm bonuses and final payment."""

    @staticmethod
    def is_displayed(player: Player):
        return C.NUM_ROUNDS == player.round_number

    @staticmethod
    def vars_for_template(player: Player):
        # Everything is read from stored Player fields (not participant.vars),
        # so the page shows exactly what is in the DB. Pre-play answers live on
        # the round-1 Player; everything else on this last-round Player.
        first_round = player.in_round(1)
        pre_level = player.chosen_level_pre
        post_level = player.chosen_level_post

        pre_section = {
            'norm_pre_level': pre_level,
            'avg_pre_personal': round(player.avg_personal_norm_pre, 2),
            'my_nn_pre': getattr(first_round, f'nn_{pre_level}'),
            'bonus_norm_pre': player.bonus_norm_pre,
            'my_emp_pre': first_round.expected_contribution,
            'bonus_emp_pre': player.bonus_emp_pre,
        }
        post_section = {
            'norm_post_level': post_level,
            'avg_post_personal': round(player.avg_personal_norm_post, 2),
            'my_nn_post': getattr(player, f'nn_{post_level}_post'),
            'bonus_norm_post': player.bonus_norm_post,
            'my_emp_post': player.expected_contribution_post,
            'bonus_emp_post': player.bonus_emp_post,
        }
        totals_section = {
            'avg_contribution_session': round(player.avg_contribution_session, 2),
            'total_bonus': (player.bonus_norm_pre + player.bonus_emp_pre +
                            player.bonus_norm_post + player.bonus_emp_post),
            'paid_rounds_str': player.paid_rounds_str,
            'pgg_payoff_sum': player.pgg_payoff_sum,
            'pgg_czk_total': player.pgg_czk_total,
            'norms_total_bonus': player.norms_total_bonus,
            'final_payment_czk': player.final_payment_czk,
        }
        return {**pre_section, **post_section, **totals_section}

page_sequence = [
    # Identification and introduction (round 1), role/group setup (every round)
    IDInput,
    Intro,
    SetupWaitPage,
    Instructions,
    # Pre-play norm elicitation (round 1)
    NormsPersonal,
    NormsNormative,
    NormsEmpirical,
    # Every round: contribute, compute payoffs, judge, punish, show results
    Cooperation,
    ResultsWaitPage,
    JudgeWaitPage,
    Judge,
    FinalWaitPage,
    Results,
    # Post-play norm elicitation (last round); payment is computed in the last round
    NormsPersonalPost,
    NormsNormativePost,
    NormsEmpiricalPost,
    NormsAndPaymentWaitPage,
    ResultsSummary,
]